import argparse
import json
import os
import shutil
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from bit_subset_parity import BitSubsetParity
from bit_subset_parity_data_module import BitSubsetParityDataModule
from pytorch_lightning.loggers import TensorBoardLogger


def parse_arguments():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Two options are declared here directly:

    * ``--compress_steps`` (int, default ``-1``)
    * ``--compress_type`` (str, default ``'tokenizer'``)

    The parser is then handed, in order, to ``Trainer``, ``BitSubsetParity``
    and ``BitSubsetParityDataModule`` so each can register its own options.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--compress_steps', default=-1, type=int)
    arg_parser.add_argument('--compress_type', default='tokenizer', type=str)

    # Each extender receives the parser and returns the parser to keep using;
    # the order matches the order in which the options are registered.
    extenders = (
        Trainer.add_argparse_args,
        BitSubsetParity.add_model_specific_args,
        BitSubsetParityDataModule.add_data_module_specific_args,
    )
    for extend in extenders:
        arg_parser = extend(arg_parser)

    return arg_parser.parse_args()


def main():
    """Train a ``BitSubsetParity`` model, test it, and store the test results.

    Steps:

    1. Seed the global RNGs and parse the command line.
    2. Resolve ``compress_steps``: ``-1`` means "use the maximum", which is
       ``log2(num_of_bits) - 2`` truncated to an int.
    3. Build the data module and the model from the parsed arguments.
    4. Fit with a TensorBoard logger, a checkpoint callback that keeps the
       single best checkpoint by ``accuracy/val``, and an LR monitor.
    5. Test the in-memory (last) weights and the best checkpoint, and dump
       both result lists as JSON into the trainer's log directory.
    6. Delete the checkpoint directory.

    Raises:
        ValueError: if ``compress_steps`` exceeds ``log2(num_of_bits) - 2``.
    """
    # NOTE(review): the global seed is fixed at 1234, while ``args.seed`` is
    # only forwarded to the data module below -- confirm this is intended.
    pl.seed_everything(1234)
    args = parse_arguments()

    # Upper bound shared by the "-1 means maximum" default and the range check.
    max_compress_steps = np.log2(args.num_of_bits) - 2
    if args.compress_steps == -1:
        args.compress_steps = int(max_compress_steps)
    if args.compress_steps > max_compress_steps:
        raise ValueError(f"Compress steps {args.compress_steps} is larger than the maximum possible value {max_compress_steps}")

    # The Trainer option may be left unset (None); treat that as "no
    # accumulation" instead of failing on ``int * None``.
    accumulate_grad_batches = args.accumulate_grad_batches
    if accumulate_grad_batches is None:
        accumulate_grad_batches = 1

    # NOTE(review): if ``max_steps`` is -1 this product is negative; how the
    # data module interprets that is not visible here -- confirm.
    data_module = BitSubsetParityDataModule(step_by_step=args.step_by_step,
                                            max_training_steps=args.max_steps * accumulate_grad_batches,
                                            num_of_bits=args.num_of_bits,
                                            compress_steps=args.compress_steps,
                                            train_batch_size=args.train_batch_size,
                                            eval_batch_size=args.eval_batch_size,
                                            eval_steps=args.eval_steps,
                                            num_workers=args.num_workers,
                                            seed=args.seed)
    data_module.prepare_data()
    data_module.setup(stage="fit")

    model = BitSubsetParity(step_by_step=args.step_by_step,
                            num_of_bits=args.num_of_bits,
                            width=args.width,
                            num_heads=args.num_heads,
                            depth=args.depth,
                            compress_steps=args.compress_steps,
                            compress_type=args.compress_type,
                            learning_rate=args.learning_rate,
                            warmup_steps=args.warmup_steps,
                            weight_decay=args.weight_decay,
                            evaluate_with_greedy_decoding=args.evaluate_with_greedy_decoding)

    step_type = "step_by_step" if args.step_by_step else "single_step"

    # Logs go under ./logs_<width>_<heads>_<depth>/compress_steps/<version>.
    logger = TensorBoardLogger(save_dir=os.path.join(os.getcwd(), f"logs_{args.width}_{args.num_heads}_{args.depth}"), name="compress_steps", version=f"{step_type}_{args.num_of_bits}_bits_{args.compress_steps}_compress_{args.compress_type}_{args.width}_width_{args.num_heads}_heads_{args.depth}_depth")

    checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="accuracy/val", mode="max", save_last=False)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = Trainer.from_argparse_args(args, logger=logger, callbacks=[checkpoint_callback, lr_monitor])
    trainer.fit(model, data_module)
    print("Training finished")

    # ckpt_path=None evaluates the weights currently held by ``model``;
    # ckpt_path="best" evaluates the best checkpoint tracked by the callback.
    last_model_test_results = trainer.test(model, datamodule=data_module, ckpt_path=None)
    best_model_test_results = trainer.test(model, datamodule=data_module, ckpt_path="best")

    # Read log_dir on every process (the property may synchronise across
    # ranks), then let only the global-zero process touch the filesystem so
    # multi-process runs do not race on the same files and directory.
    log_dir = trainer.log_dir
    if trainer.is_global_zero:
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, "last_model_test_results.json"), 'w') as f:
            json.dump(last_model_test_results, f)
        with open(os.path.join(log_dir, "best_model_test_results.json"), 'w') as f:
            json.dump(best_model_test_results, f)

        # Checkpoints are only needed for the "best" evaluation above. The
        # directory may be unset or absent if nothing was ever saved.
        checkpoint_dir = checkpoint_callback.dirpath
        if checkpoint_dir is not None and os.path.isdir(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)


# Run the training/evaluation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
